{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "| Notebook | Description | Colab |\n",
    "| --- | --- | --- |\n",
    "| [08_machine_translation/03_T5翻译模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/08_machine_translation/03_T5翻译模型.ipynb)  | T5翻译模型  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/08_machine_translation/03_T5翻译模型.ipynb) |\n",
    "\n",
    "# T5翻译模型\n",
    "\n",
    "使用transformers库加载预训练的T5模型，快速实现翻译效果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# T5Tokenizer (used in the cells below) is backed by SentencePiece, so install it\n",
    "# together with transformers. %pip, unlike !pip, installs into the environment of\n",
    "# the running kernel.\n",
    "%pip install -qU transformers sentencepiece"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用预训练好的模型预测："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xuming/opt/anaconda3/lib/python3.8/site-packages/transformers/models/auto/modeling_auto.py:660: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "\n",
    "# T5 is an encoder-decoder model, so it is loaded through AutoModelForSeq2SeqLM.\n",
    "# AutoModelWithLMHead is deprecated and will be removed in a future version.\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n",
    "\n",
    "# The input text carries the task prefix (translate English to German).\n",
    "inputs = tokenizer.encode(\n",
    "    \"translate English to German: Hugging Face is a technology company based in New York and Paris\",\n",
    "    return_tensors=\"pt\")\n",
    "\n",
    "# Beam search with 4 beams, output capped at max_length=40.\n",
    "outputs = model.generate(inputs, max_length=40, num_beams=4, early_stopping=True)\n",
    "\n",
    "# Decode the first returned sequence (special tokens are left in the text).\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "可以自己在预训练模型基础上，适配自己数据集继续训练自己的模型：\n",
    "\n",
    "- 无监督数据训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
    "\n",
    "# The input text uses <extra_id_N> sentinel placeholders; the label text lists\n",
    "# the words that go with each sentinel.\n",
    "input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt')['input_ids']\n",
    "labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt')['input_ids']\n",
    "\n",
    "# When labels are passed, the forward call builds decoder_input_ids on its own\n",
    "# and the returned output exposes the loss.\n",
    "outputs = model(input_ids=input_ids, labels=labels)\n",
    "loss = outputs.loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 有监督数据训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
    "\n",
    "# Source sentence (with its task prefix) and the reference translation,\n",
    "# each tokenized into a PyTorch tensor of token ids.\n",
    "input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt')['input_ids']\n",
    "labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt')['input_ids']\n",
    "\n",
    "# When labels are passed, the forward call builds decoder_input_ids on its own\n",
    "# and the returned output exposes the loss.\n",
    "outputs = model(input_ids=input_ids, labels=labels)\n",
    "loss = outputs.loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加label，完善的训练代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "import torch\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
    "\n",
    "# Task-specific limits, passed as max_length when tokenizing sources and targets.\n",
    "max_source_length = 512\n",
    "max_target_length = 128\n",
    "\n",
    "# Two toy English -> French training pairs.\n",
    "input_sequence_1 = \"Welcome to NYC\"\n",
    "output_sequence_1 = \"Bienvenue à NYC\"\n",
    "\n",
    "input_sequence_2 = \"HuggingFace is a company\"\n",
    "output_sequence_2 = \"HuggingFace est une entreprise\"\n",
    "\n",
    "# Source side: prepend the task prefix, then pad to the longest example in the batch.\n",
    "task_prefix = \"translate English to French: \"\n",
    "input_sequences = [input_sequence_1, input_sequence_2]\n",
    "prefixed_sequences = [task_prefix + sequence for sequence in input_sequences]\n",
    "encoding = tokenizer(\n",
    "    prefixed_sequences,\n",
    "    padding='longest',\n",
    "    max_length=max_source_length,\n",
    "    truncation=True,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "input_ids = encoding.input_ids\n",
    "attention_mask = encoding.attention_mask\n",
    "\n",
    "# Target side: same padding/truncation policy, but no return_tensors here,\n",
    "# so the ids are converted to a tensor explicitly below.\n",
    "target_encoding = tokenizer(\n",
    "    [output_sequence_1, output_sequence_2],\n",
    "    padding='longest',\n",
    "    max_length=max_target_length,\n",
    "    truncation=True,\n",
    ")\n",
    "\n",
    "# Swap every padding token id in the labels for -100 (presumably so padded\n",
    "# positions are skipped by the loss; -100 is the default ignore_index of\n",
    "# PyTorch's cross-entropy loss).\n",
    "labels = torch.tensor(target_encoding.input_ids)\n",
    "labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)\n",
    "\n",
    "# Forward pass; the loss is read from the returned output.\n",
    "loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用generate方法预测代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
    "\n",
    "# Tokenize the prefixed sentence and keep only the token-id tensor.\n",
    "input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt')['input_ids']\n",
    "\n",
    "# Generate with the default settings, then decode the first returned sequence\n",
    "# with special tokens stripped.\n",
    "outputs = model.generate(input_ids)\n",
    "translation = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(translation)\n",
    "# Das Haus ist wunderbar."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面是单条预测，可以使用batch数据预测："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
    "\n",
    "# T5 is an encoder-decoder model: the padded batch is consumed by the encoder and\n",
    "# padding positions are excluded through attention_mask. The tokenizer's own pad\n",
    "# token and default padding side are therefore used as-is; forcing left padding or\n",
    "# overriding pad_token with eos_token is a decoder-only (e.g. GPT-2) workaround\n",
    "# that is not needed here.\n",
    "\n",
    "task_prefix = 'translate English to German: '\n",
    "sentences = ['The house is wonderful.', 'I like to work in NYC.'] # use different length sentences to test batching\n",
    "inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors=\"pt\", padding=True)\n",
    "\n",
    "# Pass the attention mask along with the ids so padded positions are ignored.\n",
    "output_sequences = model.generate(\n",
    "    input_ids=inputs['input_ids'],\n",
    "    attention_mask=inputs['attention_mask'],\n",
    "    do_sample=False, # disable sampling to test if batching affects output\n",
    ")\n",
    "\n",
    "# Decode the whole batch at once, dropping special tokens.\n",
    "print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))\n",
    "\n",
    "# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}